library(tidyverse)
library(ggrepel)
source("src/media.R")

# Shared ggplot2 theme for the figures in this script: theme_bw() with
# transparent plot/legend backgrounds and enlarged text.
theme_sm <- function() {
  clear_rect <- element_rect(fill = "transparent")

  # Options that were tried and are currently disabled:
  #   panel.background = element_rect(fill = "transparent")  # transparent panel bg
  #   panel.grid.major = element_blank()                     # remove major gridlines
  #   panel.grid.minor = element_blank()                     # remove minor gridlines
  overrides <- theme(
    # transparent plot background with a white outline
    plot.background = element_rect(fill = "transparent", color = "white"),
    # transparent legend background and legend panel
    legend.background = clear_rect,
    legend.box.background = clear_rect,
    text = element_text(size = 20),
    plot.caption = element_text(size = 15)
  )

  theme_bw() + overrides
}

# Restore the saved topic-model workspace. The code further down uses
# `target_model` and `out` without defining them, so they are presumably
# restored by this load() call (or provided by src/media.R) -- TODO confirm.
load("data/stable_topic_model.RData")
# Disabled alternative: load several candidate models and pick one run.
#load("data/multiple_topic_models.RData")
#target_model <- models$runout[[6]]

# # re-order the theta matrix so that the RT-heavy topics have lower numbers
# remapping <- order(colMeans(target_model$theta[which(out$meta$publisher == "RT"), ]), decreasing = TRUE)
# theta <- target_model$theta[, remapping]

# # TOPICS TO MERGE: 
# # 34 - 57 - 87
# # 58 - 69

# if (ncol(theta) == 89) {
#     theta <- cbind(theta, (theta[, 34] + theta[, 57] + theta[,87]), (theta[, 58] + theta[, 69]))
#     theta <- theta[, -c(34, 57, 87, 58, 69)]
#     remapping <- order(colMeans(theta[which(out$meta$publisher == "RT"), ]), decreasing = TRUE)
# }

# Rows of the document-topic matrix that belong to RT articles.
rt_rows <- which(out$meta$publisher == "RT")

# Re-order the theta matrix so that the RT-heavy topics have lower numbers:
# columns are sorted by their mean proportion over RT documents, descending.
# drop = FALSE keeps a matrix even if only a single RT row exists.
first_remapping <- order(
  colMeans(target_model$theta[rt_rows, , drop = FALSE]),
  decreasing = TRUE
)
theta <- target_model$theta[, first_remapping]

# TOPICS TO MERGE (column numbers refer to the RT-first order built above):
# 34 - 57 - 87
# 58 - 69
assertthat::assert_that(ncol(theta) == 89)
# Append the two merged topics as new last columns, then drop their five
# source columns, leaving 84 untouched topics followed by the 2 merged ones.
theta <- cbind(theta, (theta[, 34] + theta[, 57] + theta[, 87]), (theta[, 58] + theta[, 69]))
theta <- theta[, -c(34, 57, 87, 58, 69)]
# Merging changes the column means, so sort by RT prevalence once more.
second_remapping <- order(
  colMeans(theta[rt_rows, , drop = FALSE]),
  decreasing = TRUE
)
theta <- theta[, second_remapping]
assertthat::assert_that(ncol(theta) == 86)

# Dominant topic per document and its proportion.
# max.col() defaults to ties.method = "random", which treats entries within a
# small relative tolerance as tied and picks one at random; "first" makes the
# assignment reproducible and consistent with the row maximum stored below.
out$meta$topic <- max.col(theta, ties.method = "first")
out$meta$topic_prob <- apply(theta, 1, max)

#' Find the topic under which a word scores highest
#'
#' @param word A single vocabulary term.
#' @param model A fitted model object exposing `vocab` (character vector) and
#'   `beta$logbeta[[1]]`, a matrix indexed as [topic, vocabulary position].
#' @return The row index (topic) with the largest `logbeta` value for `word`,
#'   or `NA_integer_` when the word does not occur exactly once in
#'   `model$vocab`.
assign_topic_for_word <- function(word, model) {
  # Look the word up in the vocabulary of the model that was passed in, not
  # in whatever `target_model` happens to exist in the global environment.
  word_idx <- which(model$vocab == word)
  if (length(word_idx) != 1) {
    return(NA_integer_)
  }
  which.max(model$beta$logbeta[[1]][, word_idx])
}


# Document-level table: the metadata, a copy of `topic` under the name
# `topic_number`, and one `prop_topic_<k>` column per column of theta.
prop_col_names <- str_c("prop_topic_", seq_len(ncol(theta)))
theta_df <- as_tibble(theta, .name_repair = ~prop_col_names)

df <- out$meta |>
  as_tibble() |>
  mutate(topic_number = topic) |>
  bind_cols(theta_df)

# Assign a short label (top three FREX words) to every topic.
# labelTopics() reports topics in the model's own order, whereas
# `topic_number` indexes the columns of the remapped and merged `theta`.
# The labels therefore go through the same steps as theta so that
# labels[k] describes column k of theta.
words <- stm::labelTopics(target_model)$frex[, 1:3]
model_labels <- apply(words, 1, str_flatten_comma)

# 1. RT-first ordering of the original topics
labels <- model_labels[first_remapping]
# 2. append labels for the two merged topics, then drop their source topics
labels <- c(
  labels,
  str_c(labels[c(34, 57, 87)], collapse = " / "),
  str_c(labels[c(58, 69)], collapse = " / ")
)
labels <- labels[-c(34, 57, 87, 58, 69)]
# 3. final RT-first ordering after the merge
labels <- labels[second_remapping]
stopifnot(length(labels) == ncol(theta))

# Append the rows stored in extra_results.parquet; their `topic` column is
# filled from `topic_number`.
extra_data <- arrow::read_parquet("/net/data/newstopics/extra_results.parquet") |>
  mutate(topic = topic_number)

#extra_data <- tibble()
df <- df |>
  # have to coerce back to numeric for the join
  mutate(day = as.numeric(as.character(day))) |>
  bind_rows(extra_data) |>
  # ordered factor with one level per integer day between the observed extremes
  mutate(day = factor(day, levels = min(day):max(day), ordered = TRUE))

# Look up each row's label by its topic number.
df[["label"]] <- labels[df[["topic_number"]]]

# Alternate day measure: whole days elapsed since 2017-01-01 00:00 Moscow
# time, stored as an ordered factor covering every integer in the observed
# range. The levels are fixed before BBC and Guardian rows are removed.
df <- df |>
  mutate(
    alt_day = floor(as.numeric(
      creation_date - lubridate::ymd("20170101", tz = "Europe/Moscow"),
      units = "days"
    )),
    alt_day = factor(alt_day, levels = min(alt_day):max(alt_day), ordered = TRUE)
  ) |>
  filter(publisher != "BBC", publisher != "Guardian") |>
  # drop the factor levels of the publishers that were just removed
  mutate(publisher = fct_drop(publisher))

# Article count and summed Facebook engagement per publisher and topic.
aggdf <- df |>
  group_by(publisher, topic_number, topic, label) |>
  summarize(
    n = n(),
    engagement = sum(fb_data.total_engagement_count),
    .groups = "drop"
  )

# Within-publisher shares of articles and of engagement. The result stays
# grouped by publisher.
plot_df <- aggdf |>
  group_by(publisher) |>
  mutate(across(c(n, engagement), \(x) x / sum(x), .names = "{.col}_pct"))


# Bar chart of topic shares, one facet row per publisher.
#
# df:              data with topic_number, publisher, n_pct and (when
#                  with_engagement is TRUE) engagement_pct columns.
# with_engagement: overlay solid engagement-share bars on the outlined
#                  article-share bars.
# title:           plot title.
#
# Returns the ggplot object. The x axis is cut off at 40.5.
topic_distplot <- function(df,
                           with_engagement = TRUE,
                           title = "Topic distribution for selected publishers, 2017-2021") {
  caption_text <- paste(
    "Excluding sports and entertainment news, minimum 10 engagements on Facebook.",
    "Only top 40 topics covered by RT shown",
    sep = "\n"
  )

  # Outlined (unfilled) bars: share of articles
  article_bars <- geom_col(aes(y = n_pct), fill = NA, color = "#e66101", linewidth = 1)

  plt <- ggplot(df, aes(x = topic_number, group = publisher)) +
    article_bars +
    facet_wrap(vars(publisher), ncol = 1, scales = "free") +
    labs(
      x = "Topic",
      y = "Percent of Articles",
      title = title,
      fill = "Topic",
      caption = caption_text
    ) +
    theme_sm() +
    scale_y_continuous(labels = scales::percent) +
    coord_cartesian(xlim = c(0, 40.5))

  if (!with_engagement) {
    return(plt)
  }

  # Solid bars: share of engagement, drawn on top of the outlines
  plt +
    geom_col(aes(y = engagement_pct), alpha = 0.8, fill = "#5e3c99") +
    labs(subtitle = "Rectangles show articles, solid bars show engagement")
}
# Build-up sequence of topic-distribution figures: RT alone (without and with
# engagement), then RT next to other groups of publishers.
rt_title <- "Topic distribution for Russia Today, 2017-2021"
rt_only <- filter(plot_df, publisher == "RT")

p0 <- topic_distplot(rt_only, with_engagement = FALSE, title = rt_title)
p1 <- topic_distplot(rt_only, title = rt_title)
p2 <- topic_distplot(filter(plot_df, publisher %in% c("RT", "Sputnik")))
p3 <- topic_distplot(filter(plot_df, publisher %in% c("RT", "Breitbart", "Gateway Pundit")))
p4 <- topic_distplot(filter(plot_df, publisher %in% c("RT", "New York Times", "Washington Post")))
p5 <- topic_distplot(filter(plot_df, publisher %in% c("RT", "Vox", "Slate")))

# Save as build_dist_plot_0.png ... build_dist_plot_5.png
dist_plots <- list(p0, p1, p2, p3, p4, p5)
for (idx in seq_along(dist_plots)) {
  ggsave(
    sprintf("results/figures/build_dist_plot_%d.png", idx - 1L),
    dist_plots[[idx]],
    width = 20, height = 8, bg = "transparent", dpi = 300
  )
}


# Publisher x topic matrices of article counts and of engagement, each row
# divided by its own total. Missing publisher/topic combinations count as 0.
n_df <- pivot_wider(
  select(aggdf, publisher, topic_number, n),
  names_from = topic_number,
  values_from = n,
  values_fill = 0
)
n_mat <- as.matrix(select(n_df, -publisher))
rownames(n_mat) <- n_df$publisher
n_mat <- sweep(n_mat, 1, rowSums(n_mat), "/")

engagement_df <- pivot_wider(
  select(aggdf, publisher, topic_number, engagement),
  names_from = topic_number,
  values_from = engagement,
  values_fill = 0
)
engagement_mat <- as.matrix(select(engagement_df, -publisher))
rownames(engagement_mat) <- engagement_df$publisher
engagement_mat <- sweep(engagement_mat, 1, rowSums(engagement_mat), "/")

#' Kullback-Leibler divergence KL(p || q), natural log
#'
#' Both inputs are rescaled to sum to 1. Exact zeros are then replaced by
#' 1e-10 so that neither the ratio nor the logarithm is undefined.
#'
#' @param p,q Non-negative numeric vectors of equal length.
#' @return A single number.
kl_divergence <- function(p, q) {
  stopifnot(length(p) == length(q)) # guard against silent recycling
  p <- p / sum(p) # Normalize p
  q <- q / sum(q) # Normalize q
  p[p == 0] <- 1e-10 # Avoid log(0)
  q[q == 0] <- 1e-10 # Avoid division by zero
  sum(p * log(p / q))
}

#' Jensen-Shannon divergence between two discrete distributions, natural log
#'
#' @param p,q Non-negative numeric vectors of equal length; they do not need
#'   to be normalized.
#' @return A single number: the mean of KL(p || m) and KL(q || m), where m is
#'   the average of the normalized p and q.
jensen_shannon_divergence <- function(p, q) {
  stopifnot(length(p) == length(q))
  # Normalize BEFORE averaging: otherwise the vector with the larger total
  # dominates the mixture and the result depends on the inputs' scale.
  p <- p / sum(p)
  q <- q / sum(q)
  mixture <- (p + q) / 2 # Average distribution
  (kl_divergence(p, mixture) + kl_divergence(q, mixture)) / 2
}

# Jensen-Shannon divergence of every publisher's row from RT's row, computed
# separately for the article matrix and the engagement matrix. Rows are
# visited by position in n_mat; results are named after each matrix's rows.
rt_n_row <- n_mat["RT", ]
rt_engagement_row <- engagement_mat["RT", ]
row_ids <- seq_len(nrow(n_mat))

n_jsds <- vapply(
  row_ids,
  \(i) jensen_shannon_divergence(rt_n_row, n_mat[i, ]),
  numeric(1)
)
engagement_jsds <- vapply(
  row_ids,
  \(i) jensen_shannon_divergence(rt_engagement_row, engagement_mat[i, ]),
  numeric(1)
)

names(n_jsds) <- rownames(n_mat)
names(engagement_jsds) <- rownames(engagement_mat)

# (p1 <-  aggdf |>
#     ggplot(aes(x = topic_number, y=engagement, fill=topic)) +
#     geom_bar(stat = "identity") +
#     facet_wrap(vars(publisher),
#                ncol = 1,
#                scales="free") +
#     labs(x = 'Topic',
#          y = 'Number of Articles',
#          title = 'Topic distribution for select domains, 2017-2021',
#          subtitle = 'Articles weighted by engagement',
#          fill = "Topic",
#          caption = 'Excluding sports and entertainment news, minimum 10 engagements on Facebook.') +
#     theme_sm() +
#     scale_fill_brewer(type="qual", palette=3, na.value="gray50"))

# (p2 <- aggdf |>
#     ggplot(aes(x = topic_number, y = n, fill = topic)) +
#     geom_bar(stat = "identity") +
#     facet_wrap(vars(publisher),
#                ncol = 1,
#                scales="free") +
#     labs(x = 'Topic',
#          y = 'Number of Articles',
#          title = 'Topic distribution for select domains, 2017-2021',
#          subtitle = 'Articles not weighted by engagement',
#          fill = "Topic",
#          caption = 'Excluding sports and etertainment news, minimum 10 engagements on Facebook.') +
#     theme_sm() +
#     scale_fill_brewer(type="qual", palette=3, na.value="gray50"))

# ggsave("results/figures/topic_distribution_msm_vs_russia_weighted.png", p1, width=20, height=8, bg = "transparent", dpi=300)
# ggsave("results/figures/topic_distribution_msm_vs_russia_unweighted.png", p2, width=20, height=8, bg = "transparent", dpi=300)


# posts_per_publisher_on_topic <- function(topic) {
#   ggplot(df |> filter(topic == topic),
#          aes(x = creation_date,
#              y = publisher,
#              size = log10(fb_data.total_engagement_count))) +
#     geom_point(alpha = 0.4) +
#     scale_size(range = c(0.2, 4), guide = NULL) +
#     theme_sm()  +
#     labs(x = "Date",
#          y = "Publisher",
#          size = "Engagement",
#          title = glue::glue("Posts per publisher on {topic}"),
#          caption = "Dashed lines separate groups of publishers. Points sized by logged Facebook engagements") +
#     geom_hline(aes(yintercept = 19.5), alpha = 0.6, linetype='dashed', linewidth = 0.2) +
#     geom_hline(aes(yintercept = 32.5), alpha = 0.6, linetype='dashed', linewidth = 0.2) +
#     geom_hline(aes(yintercept = 77.5), alpha = 0.6, linetype='dashed', linewidth = 0.2)
# }

# (p3 <- posts_per_publisher_on_topic("Syria"))
# (p4 <- posts_per_publisher_on_topic("Immigration"))
# (p5 <- posts_per_publisher_on_topic("Russia"))
# (p6 <- posts_per_publisher_on_topic("Coronavirus"))
# (p7 <- posts_per_publisher_on_topic("Ukraine"))

# ggsave("results/figures/posts_per_publisher_syria.png", p3, width=20, height=10, bg = "transparent", dpi=300)
# ggsave("results/figures/posts_per_publisher_immigration.png", p4, width=20, height=10, bg = "transparent", dpi=300)
# ggsave("results/figures/posts_per_publisher_russia.png", p5, width=20, height=10, bg = "transparent", dpi=300)
# ggsave("results/figures/posts_per_publisher_coronavirus.png", p6, width=20, height=10, bg = "transparent", dpi=300)

# ggsave("results/figures/posts_per_publisher_ukraine.png", p7, width=20, height=10, bg = "transparent", dpi=300)

# df <- mutate(df, russia = as.integer(str_detect(str_to_lower(headline_and_blurb), "russia")))
# df <- mutate(df, ukraine = as.integer(str_detect(str_to_lower(headline_and_blurb), "ukrain")))
# df <- mutate(df, hunter = as.integer(str_detect(str_to_lower(headline_and_blurb), "hunter")))
# df <- mutate(df, crimea = as.integer(str_detect(str_to_lower(headline_and_blurb), "crimea")))


# p5_alt <- ggplot(df |> filter(ukraine == 1),
#                  aes(x = creation_date,
#                      y = publisher,
#                      size = log(fb_data.total_engagement_count))) +
#   geom_point(alpha = 0.4) +
#   scale_size(range = c(0.2, 5), guide = NULL) +
#   theme_sm()  +
#   labs(x = "Date",
#        y = "Publisher",
#        size = "Engagement",
#        title = glue::glue("Posts per publisher on Russia"),
#        subtitle = "Using keyword only, not topic",
#        caption = "Dashed lines separate groups of publishers. Points sized by logged Facebook engagements") +
#   geom_hline(aes(yintercept = 19.5), alpha = 0.6, linetype='dashed', linewidth = 0.2) +
#   geom_hline(aes(yintercept = 32.5), alpha = 0.6, linetype='dashed', linewidth = 0.2) +
#   geom_hline(aes(yintercept = 78.5), alpha = 0.6, linetype='dashed', linewidth = 0.2)

# ggsave("results/figures/posts_per_publisher_ukraine_alt.png", p5_alt, width=20, height=10, bg = "transparent", dpi=300)


# Per-publisher concentration across topics: ineq::ineq() applied to the
# article counts and to the engagement totals, plus the total article count.
gini_by_publisher <- aggdf |>
  group_by(publisher) |>
  summarize(
    n_articles = sum(n),
    n_gini = ineq::ineq(n),
    engagement_gini = ineq::ineq(engagement)
  ) |>
  mutate(type = code_categories(publisher)) |>
  arrange(desc(engagement_gini))

# Scatter of the two coefficients; the dashed diagonal marks equal values.
gini_p <- ggplot(
  gini_by_publisher,
  aes(x = n_gini, y = engagement_gini, size = n_articles, color = type)
) +
  geom_point() +
  geom_text_repel(aes(label = publisher), max.overlaps = 7, seed = 20) +
  scale_color_brewer(type = "qual", palette = 6, na.value = "gray50") +
  theme_sm() +
  labs(
    title = "Concentration of topic engagement and coverage",
    x = "Gini coefficient, article coverage",
    y = "Gini coefficient, article engagement"
  ) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", linewidth = 0.2, alpha = 0.6)

ggsave("results/figures/gini_plot.png", gini_p, width = 10, height = 6, bg = "white", dpi = 300)



# One row per publisher other than RT, holding its divergence from RT in
# article shares and in engagement shares.
jsd_by_publisher <- tibble(
  article_jsd = n_jsds,
  engagement_jsd = engagement_jsds,
  publisher = names(n_jsds)
) |>
  filter(publisher != "RT") |>
  mutate(type = code_categories(publisher))

# Scatter of the two divergences with a dashed linear fit.
dist_p <- ggplot(
  jsd_by_publisher,
  aes(x = article_jsd, y = engagement_jsd, color = type)
) +
  theme_sm() +
  geom_point() +
  # coord_cartesian(xlim=c(0, 0.5), ylim=c(0, 0.5)) +
  geom_smooth(method = "lm", color = "black", alpha = 0.8, linetype = "dashed", se = FALSE) +
  geom_text_repel(aes(label = publisher), max.overlaps = 7, seed = 20) +
  labs(
    title = "Dissimilarity of publishers to RT",
    caption = "Lower values mean topic distribution is more similar to RT",
    x = "Articles",
    y = "Engagement",
    color = "Publisher Type"
  ) +
  scale_color_brewer(type = "qual", palette = 6, na.value = "gray50")

ggsave("results/figures/distance_scatter.png", dist_p, width = 20, height = 10, bg = "transparent", dpi = 300)


# VAR data
# df |>
#   group_by(day, publisher, grouping, topic_number) |>
#   mutate(
#     num_articles_published = n(),
#     total_engagement = sum(fb_data.total_engagement_count)) |>
#   write_tsv("/net/data/newstopics/var_model_data_full.tsv.gz")

# Write a panel of mean topic proportions to a TSV file.
#
# data:    document-level table with `publisher`, `grouping`, the day column
#          and the `prop_topic_<k>` columns.
# day_col: unquoted name of the day column to aggregate over.
# path:    output file.
#
# The output has one row per day/topic and one column per publisher; `topic`
# holds the number taken from the `prop_topic_<k>` column name.
# .drop = FALSE keeps empty factor-level combinations as groups.
write_var_model_data <- function(data, day_col, path) {
  data |>
    group_by({{ day_col }}, publisher, grouping, .drop = FALSE) |>
    # across() replaces the superseded summarize_at()
    summarize(across(starts_with("prop_"), mean)) |>
    pivot_longer(prop_topic_1:last_col()) |>
    select(-grouping) |>
    pivot_wider(names_from = publisher, values_from = value) |>
    mutate(name = str_extract(name, "[a-z_]+([\\d]+)", group = 1)) |>
    rename(topic = name) |>
    write_tsv(path)
}

# Panel keyed by the original `day` column
write_var_model_data(
  df, day,
  "/net/data/newstopics/alt_var_model_data_full.tsv.gz"
)

# Same panel keyed by `alt_day`
write_var_model_data(
  df, alt_day,
  "/net/data/newstopics/alt_var_model_data_moscow_time_full.tsv.gz"
)
